import torch
import torch.nn as nn
from tensordict import TensorDict
from torch import Tensor
from functools import partial
from typing import Tuple
from einops import rearrange
from typing import Literal

class Actor(nn.Module):
    """Encoder/decoder network assembled from injected sub-modules.

    The encoder path is chosen by ``method``:

    * ``"mlp"``: stem -> ``blocks`` -> ``norm`` -> swap the last two axes ->
      ``reduce_seq_layer`` -> merge the last two axes -> ``reduce_embed_layer``.
    * ``"transformer"``: stem -> add position embedding -> ``blocks`` ->
      ``norm`` -> keep the last position along the second-to-last axis.

    The encoder output is then passed through ``decoder_layer``.

    Args:
        *args: Accepted and ignored. NOTE(review): presumably kept so the
            class can be built from a generic config; confirm with callers.
        encoder_layer: String-indexable container (e.g. ``nn.ModuleDict``)
            providing ``'stem_layer'``, plus ``'indices'`` for the
            ``"transformer"`` path (called with an integer position tensor).
        middle_layer: String-indexable container providing
            ``'reduce_seq_layer'`` and ``'reduce_embed_layer'``; only used by
            the ``"mlp"`` path.
        decoder_layer: Module applied to the encoder output.
        blocks: Feature-extraction module applied after the stem.
        norm: Normalization module applied after ``blocks``.
        method: Encoder variant, ``"mlp"`` or ``"transformer"``.
        device: Device every sub-module is moved to at construction time.
        seq_len: Number of positions fed to ``'indices'`` in the
            ``"transformer"`` path. The resulting embedding is added to the
            stem output, so (assumption, TODO confirm) the stem output's
            second-to-last axis is expected to have length ``seq_len``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 *args,
                 encoder_layer: nn.Module,
                 middle_layer: nn.Module,
                 decoder_layer: nn.Module,
                 blocks: nn.Module,
                 norm: nn.Module,
                 method: Literal["mlp", "transformer"] = "mlp",
                 device: str = "cpu",
                 seq_len: int = 64,
                 **kwargs
                 ):
        super().__init__()

        self.device = device
        self.seq_len = seq_len
        self.encoder_layer = encoder_layer.to(device)
        self.middle_layer = middle_layer.to(device)
        self.decoder_layer = decoder_layer.to(device)
        self.blocks = blocks.to(device)
        self.norm = norm.to(device)
        self.method = method

    def forward_encoder(self, features: TensorDict) -> Tensor:
        """Encode ``features`` into a latent tensor using ``self.method``.

        Args:
            features: Input passed to the ``'stem_layer'`` sub-module.

        Returns:
            The latent produced by the selected encoder path.

        Raises:
            ValueError: If ``self.method`` is not ``"mlp"`` or ``"transformer"``.
        """
        if self.method == "mlp":
            return self._encode_mlp(features)
        if self.method == "transformer":
            return self._encode_transformer(features)
        # Fail loudly with a clear message instead of falling through and
        # returning an unbound local.
        raise ValueError(
            f"Unknown encoder method {self.method!r}; "
            "expected 'mlp' or 'transformer'."
        )

    def _encode_mlp(self, features: TensorDict) -> Tensor:
        """Run the MLP path: stem, blocks, norm, then sequence/embedding reduction."""
        # NOTE(review): nn.Module.to moves the stem sub-module in place to the
        # input's device on every call; the other sub-modules are not moved.
        stem_layer = self.encoder_layer['stem_layer'].to(features.device)
        x = stem_layer(features)

        x = self.blocks(x)  # extract features
        x = self.norm(x)  # normalize

        # Swap the last two axes so that reduce_seq_layer (presumably acting
        # on the last axis) operates on what was the second-to-last axis.
        x = rearrange(x, '... d n -> ... n d')
        x = self.middle_layer['reduce_seq_layer'](x)

        # Merge the last two axes into one, with the last axis varying slowest
        # ('(d n)' order). Keep this order: reduce_embed_layer's trained
        # weights presumably depend on it.
        x = rearrange(x, '... n d -> ... (d n)')
        x = self.middle_layer['reduce_embed_layer'](x)
        return x

    def _encode_transformer(self, features: TensorDict) -> Tensor:
        """Run the transformer path: stem plus position embedding, blocks, norm, last position."""
        x = self.encoder_layer['stem_layer'](features)

        # Build the position indices directly on x's device rather than
        # creating them on the CPU and copying them over.
        positions = torch.arange(self.seq_len, device=x.device)
        x = x + self.encoder_layer['indices'](positions)  # add position embedding

        if x.dim() > 3:
            # Fold the first two axes together before `blocks` (a 4-D input
            # becomes 3-D), then unfold them again afterwards.
            b, n = x.shape[:2]
            x = rearrange(x, 'b n ... -> (b n) ...', b=b, n=n)
            x = self.blocks(x)  # extract features
            x = rearrange(x, '(b n) ... -> b n ...', b=b, n=n)
        else:
            x = self.blocks(x)  # extract features

        x = self.norm(x)

        # Keep only the last position along the second-to-last axis.
        return x[..., -1, :]

    def decoder(self, x: Tensor) -> Tensor:
        """Apply ``decoder_layer`` to the encoder latent ``x``."""
        return self.decoder_layer(x)

    def forward(self, x: TensorDict) -> Tensor:
        """Encode ``x`` with :meth:`forward_encoder`, then decode the latent."""
        latent = self.forward_encoder(x)
        return self.decoder(latent)